"""
Generate publication-quality figures for the gravity bilateral followup paper.

Figures:
  1. KAOPEN marginal effects plot
  2. Jackknife coefficient stability (horizontal bars with 95% error bars)
  3. Net demographic reallocation pressure by country (2050)
  4. Coefficient comparison across specifications (forest plot)

Output: <PROJECT>/output/figures/*.png at 300 dpi (PROJECT is defined below)
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from pathlib import Path

# ── Paths ──────────────────────────────────────────────────────────────────
# Input tables and output figures both live under <PROJECT>/output.
PROJECT = Path("/mnt/c/demographics_capital_flows/gravity_bilateral_followup")
_OUTPUT_DIR = PROJECT / "output"
TABLES = _OUTPUT_DIR / "tables"
FIGURES = _OUTPUT_DIR / "figures"
# Create the figure directory (and any missing parents) at import time.
FIGURES.mkdir(parents=True, exist_ok=True)

# ── Style ──────────────────────────────────────────────────────────────────
# The settings are grouped by theme and merged into one rcParams update.
_FONT_STYLE = {
    "font.family": "serif",
    "font.serif": ["Times New Roman", "DejaVu Serif", "serif"],
    "font.size": 11,
}
_AXES_STYLE = {
    "axes.linewidth": 0.8,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": False,
}
_TICK_STYLE = {
    "xtick.direction": "out",
    "ytick.direction": "out",
    "xtick.major.size": 4,
    "ytick.major.size": 4,
}
_RESOLUTION_STYLE = {
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.15,
}
plt.rcParams.update(
    {**_FONT_STYLE, **_AXES_STYLE, **_TICK_STYLE, **_RESOLUTION_STYLE}
)

# ── Load data ──────────────────────────────────────────────────────────────
# Each results table is read once from TABLES, in the order listed, and bound
# to the module-level name the figure functions use.
gravity, jackknife, projections, ppml, referee, robustness = (
    pd.read_csv(TABLES / table_file)
    for table_file in (
        "gravity_results.csv",
        "jackknife_results.csv",
        "projection_summary_by_country.csv",
        "ppml_results.csv",
        "referee_robustness.csv",
        "gravity_robustness.csv",
    )
)


# ═══════════════════════════════════════════════════════════════════════════
# Figure 1: KAOPEN Marginal Effects
# ═══════════════════════════════════════════════════════════════════════════
def figure1_kaopen_marginal(kaopen_max=2.28):
    """Plot the marginal effect of dZ_1 on bilateral holdings against KAOPEN_j.

    The marginal effect is ``beta(dZ_1) + beta(dZ_1_x_kaopen_j) * KAOPEN_j``,
    using the Model 2c coefficients from the module-level ``gravity`` table.
    It is evaluated on 200 evenly spaced KAOPEN values in ``[0, kaopen_max]``
    and drawn with a 95% band (+/- 1.96 standard errors).

    The standard error combines the two coefficient standard errors and
    treats their covariance as zero, so the band is only an approximation
    whenever the two estimates are correlated.

    The figure is written to ``FIGURES / "fig1_kaopen_marginal_effects.png"``.

    Args:
        kaopen_max: Upper end of the KAOPEN grid and of the x-axis.

    Raises:
        IndexError: If the Model 2c rows for ``dZ_1`` or ``dZ_1_x_kaopen_j``
            are missing from ``gravity``.
    """
    model_name = "2c: Gravity + Demographics + KAOPEN interactions"
    m2c = gravity[gravity["model"] == model_name]

    def _estimate(variable):
        """Return (coefficient, std_error) of *variable* in Model 2c."""
        rows = m2c[m2c["variable"] == variable]
        if rows.empty:
            # Same exception type as an out-of-range ``.values[0]``, but the
            # message names what is actually missing.
            raise IndexError(
                f"gravity results contain no {variable!r} row "
                f"for model {model_name!r}"
            )
        return rows["coefficient"].iloc[0], rows["std_error"].iloc[0]

    beta_dz1, se_dz1 = _estimate("dZ_1")
    beta_int, se_int = _estimate("dZ_1_x_kaopen_j")

    kaopen = np.linspace(0, kaopen_max, 200)
    marginal = beta_dz1 + beta_int * kaopen

    # Delta-method SE with the covariance term omitted (assumed zero).
    se_marginal = np.sqrt(se_dz1**2 + kaopen**2 * se_int**2)
    ci_lo = marginal - 1.96 * se_marginal
    ci_hi = marginal + 1.96 * se_marginal

    # First grid point at which the lower bound is strictly positive, if any.
    sig_mask = ci_lo > 0
    sig_kaopen = kaopen[np.argmax(sig_mask)] if sig_mask.any() else None

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.fill_between(kaopen, ci_lo, ci_hi, alpha=0.18, color="#2166ac", linewidth=0)
    ax.plot(kaopen, marginal, color="#2166ac", linewidth=2)
    ax.axhline(0, color="0.3", linewidth=0.7, linestyle="--", zorder=0)

    if sig_kaopen is not None:
        # Mark the threshold with a vertical line and a dot on the effect line.
        ax.axvline(sig_kaopen, color="#b2182b", linewidth=0.9, linestyle=":",
                   label=f"Significance threshold (KAOPEN = {sig_kaopen:.2f})")
        ax.plot(sig_kaopen, beta_dz1 + beta_int * sig_kaopen, "o",
                color="#b2182b", markersize=5, zorder=5)
        ax.legend(loc="upper left", frameon=False, fontsize=9.5)

    ax.set_xlabel("KAOPEN$_j$ (destination financial openness)", fontsize=12)
    ax.set_ylabel("Marginal effect of $\\Delta Z_1$ on log(bilateral position)", fontsize=12)
    ax.set_title("Marginal Effect of Demographic Distance\non Bilateral Portfolio Holdings",
                 fontsize=13, fontweight="bold", pad=12)
    ax.set_xlim(0, kaopen_max)

    fig.savefig(FIGURES / "fig1_kaopen_marginal_effects.png")
    plt.close(fig)
    print("Figure 1 saved.")


# ═══════════════════════════════════════════════════════════════════════════
# Figure 2: Jackknife Coefficient Stability
# ═══════════════════════════════════════════════════════════════════════════
def figure2_jackknife(full_coef=0.815):
    """Plot the leave-one-region-out jackknife estimates of the dZ_1 coefficient.

    Reads the ``dZ_1`` rows of the module-level ``jackknife`` table and draws
    one horizontal bar per excluded region, sorted by coefficient. Bars are
    coloured by p-value (below 0.05, below 0.10, otherwise) and carry
    +/- 1.96 SE error bars. A dashed vertical line marks ``full_coef``.

    The standard errors are not read from the table. They are backed out of
    the coefficient and its p-value under a two-sided normal approximation,
    ``|coef| / |z|`` with ``z = norm.ppf(p / 2)``.

    The figure is written to ``FIGURES / "fig2_jackknife_stability.png"``.

    Args:
        full_coef: Full-sample estimate shown as the dashed reference line
            and quoted in the legend.
    """
    from matplotlib.patches import Patch
    from scipy import stats

    jk = jackknife[jackknife["variable"] == "dZ_1"].sort_values("coefficient").copy()

    # Colour by significance
    colors = ["#2166ac" if p < 0.05 else "#b2182b" if p < 0.10 else "0.55"
              for p in jk["p_value"]]

    # Approximate SE from coefficient and p-value (two-sided normal). The
    # floor of 0.01 on |z| avoids dividing by zero when the p-value is 1.
    z_abs = np.abs(stats.norm.ppf(jk["p_value"] / 2))
    jk["approx_se"] = np.abs(jk["coefficient"]) / np.maximum(z_abs, 0.01)

    fig, ax = plt.subplots(figsize=(7, 5.5))
    y_pos = np.arange(len(jk))
    labels = jk["excluded_region"].values

    ax.barh(y_pos, jk["coefficient"].values, height=0.55, color=colors, alpha=0.85,
            edgecolor="white", linewidth=0.5)
    # Error bars
    ax.errorbar(jk["coefficient"].values, y_pos,
                xerr=1.96 * jk["approx_se"].values,
                fmt="none", ecolor="0.3", elinewidth=0.8, capsize=2.5)

    # The reference line gets its legend entry from the explicit handles below.
    ax.axvline(full_coef, color="0.2", linewidth=1.2, linestyle="--")
    ax.axvline(0, color="0.5", linewidth=0.6, linestyle="-")

    ax.set_yticks(y_pos)
    ax.set_yticklabels([f"Excl. {r}" for r in labels], fontsize=9.5)
    ax.set_xlabel("$\\Delta Z_1$ coefficient", fontsize=12)
    ax.set_title("Leave-One-Region-Out Jackknife: $\\Delta Z_1$ Coefficient",
                 fontsize=13, fontweight="bold", pad=12)

    # Single legend: colour key for the bars plus the full-sample line.
    legend_elements = [
        Patch(facecolor="#2166ac", label="p < 0.05"),
        Patch(facecolor="#b2182b", label="0.05 < p < 0.10"),
        Patch(facecolor="0.55", label="p > 0.10"),
        plt.Line2D([0], [0], color="0.2", linewidth=1.2, linestyle="--",
                    label=f"Full sample ({full_coef:.3f})"),
    ]
    ax.legend(handles=legend_elements, loc="lower right", frameon=False, fontsize=9)

    fig.savefig(FIGURES / "fig2_jackknife_stability.png")
    plt.close(fig)
    print("Figure 2 saved.")


# ═══════════════════════════════════════════════════════════════════════════
# Figure 3: Net Demographic Reallocation Pressure by Country (2050)
# ═══════════════════════════════════════════════════════════════════════════
def figure3_reallocation():
    """Draw projected net demographic reallocation pressure as horizontal bars.

    Rows of the module-level ``projections`` table are ordered by
    ``net_pressure`` ascending, so the largest value ends up at the top.
    Strictly positive bars are red and all others blue; the y-axis is
    labelled with the ``iso3`` codes. The figure is written to
    ``FIGURES / "fig3_reallocation_pressure_2050.png"``.
    """
    outward_colour = "#b2182b"
    received_colour = "#2166ac"

    ranked = projections.sort_values("net_pressure", ascending=True)
    pressure = ranked["net_pressure"].values
    bar_colours = np.where(pressure > 0, outward_colour, received_colour).tolist()
    rows = np.arange(len(ranked))

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.barh(
        rows,
        pressure,
        height=0.65,
        color=bar_colours,
        alpha=0.85,
        edgecolor="white",
        linewidth=0.5,
    )
    ax.axvline(0, color="0.3", linewidth=0.7)

    ax.set_yticks(rows)
    ax.set_yticklabels(ranked["iso3"].values, fontsize=9)
    ax.set_xlabel("Net reallocation pressure (outward $-$ received)", fontsize=11)
    ax.set_title("Projected Net Demographic Reallocation\nPressure by 2050",
                 fontsize=13, fontweight="bold", pad=12)

    # Corner notes, in axes coordinates, naming the direction of each colour.
    corner_notes = (
        (0.97, 0.97, "Aging $\\rightarrow$ outward", "right", "top", outward_colour),
        (0.03, 0.03, "Young $\\rightarrow$ received", "left", "bottom", received_colour),
    )
    for x_frac, y_frac, note, h_align, v_align, colour in corner_notes:
        ax.text(x_frac, y_frac, note,
                transform=ax.transAxes, ha=h_align, va=v_align,
                fontsize=9, color=colour, fontstyle="italic")

    fig.savefig(FIGURES / "fig3_reallocation_pressure_2050.png")
    plt.close(fig)
    print("Figure 3 saved.")


# ═══════════════════════════════════════════════════════════════════════════
# Figure 4: Coefficient Comparison (Forest Plot) Across Specifications
# ═══════════════════════════════════════════════════════════════════════════
def _dz1_estimate(table, model, *, substring=False, se_fallback=None):
    """Return (coefficient, std_error, p_value) of the first dZ_1 row of *model*.

    Args:
        table: Results frame with ``model``, ``variable``, ``coefficient``,
            ``std_error`` and ``p_value`` columns.
        model: Model label to look up.
        substring: If true, match rows whose ``model`` contains *model* as a
            literal substring (missing labels never match); otherwise
            require an exact match.
        se_fallback: Optional column to read the standard error from when
            ``std_error`` is missing (NaN) for the selected row.

    Raises:
        IndexError: If no row matches.
    """
    labels = table["model"]
    if substring:
        model_mask = labels.str.contains(model, regex=False, na=False)
    else:
        model_mask = labels == model
    rows = table[model_mask & (table["variable"] == "dZ_1")]
    if rows.empty:
        raise IndexError(f"no 'dZ_1' row found for model {model!r}")

    se = rows["std_error"].iloc[0]
    if se_fallback is not None and pd.isna(se):
        se = rows[se_fallback].iloc[0]
    return rows["coefficient"].iloc[0], se, rows["p_value"].iloc[0]


def _significance_colour(p):
    """Map a p-value to the marker colour used in the forest plot."""
    if p < 0.01:
        return "#2166ac"
    if p < 0.05:
        return "#4393c3"
    if p < 0.10:
        return "#d6604d"
    return "0.55"


def figure4_forest(reference_coef=0.815):
    """Draw a forest plot of the dZ_1 coefficient across key specifications.

    One point with a 95% interval (+/- 1.96 SE) is drawn per specification,
    listed top to bottom in the order defined below. Estimates come from the
    module-level ``gravity``, ``ppml``, ``referee`` and ``robustness`` tables.
    Markers are coloured by p-value, and a dashed vertical line marks
    ``reference_coef``.

    The figure is written to ``FIGURES / "fig4_coefficient_forest.png"``.

    Args:
        reference_coef: Position of the dashed "Full-sample 2b" reference line.

    Raises:
        IndexError: If any specification has no dZ_1 row in its table.
    """
    from matplotlib.patches import Patch

    # The pair-FE rows may lack ``std_error``; fall back to ``std_error_gls``.
    pair_fe = {"se_fallback": "std_error_gls"}
    # (display label, source table, model label, lookup options)
    spec_sources = [
        ("Model 2b (pooled GLS)", gravity, "2b: Gravity + Demographics", {}),
        ("Model 2c (+ KAOPEN)", gravity,
         "2c: Gravity + Demographics + KAOPEN interactions", {}),
        ("Portfolio Debt", gravity, "2d: Portfolio Debt", {}),
        ("PPML 2b", ppml, "PPML 2b", {"substring": True}),
        ("Pair FE 2b", referee, "2b: Pair FE + Year FE", pair_fe),
        ("Pair FE 2c", referee, "2c: Pair FE + Year FE", pair_fe),
        ("Excl Adv. Europe", robustness, "3b: Excl Advanced Europe", {}),
        ("Excl MENA", robustness, "3b: Excl Middle East & North Africa", {}),
        ("Excl SSA", robustness, "3b: Excl Sub-Saharan Africa", {}),
    ]

    specs = []
    for label, table, model, options in spec_sources:
        coef, se, p = _dz1_estimate(table, model, **options)
        specs.append({"label": label, "coef": coef, "se": se, "p": p})

    # Reverse so the first specification is drawn at the top of the axis.
    df = pd.DataFrame(specs).iloc[::-1]

    fig, ax = plt.subplots(figsize=(7, 5))
    y_pos = np.arange(len(df))
    colors = [_significance_colour(p) for p in df["p"]]

    ax.errorbar(df["coef"].values, y_pos,
                xerr=1.96 * df["se"].values,
                fmt="o", markersize=6, color="0.2",
                ecolor="0.4", elinewidth=1.2, capsize=3.5,
                linewidth=0, zorder=5)
    # Overlay significance-coloured markers on top of the neutral ones.
    for coef, y, colour in zip(df["coef"].values, y_pos, colors):
        ax.plot(coef, y, "o", markersize=7, color=colour, zorder=6)

    ax.axvline(0, color="0.5", linewidth=0.7, linestyle="-")
    # The reference line gets its legend entry from the explicit handles below.
    ax.axvline(reference_coef, color="0.2", linewidth=0.9, linestyle="--", alpha=0.5)

    ax.set_yticks(y_pos)
    ax.set_yticklabels(df["label"].values, fontsize=10)
    ax.set_xlabel("$\\Delta Z_1$ coefficient (95% CI)", fontsize=12)
    ax.set_title("Coefficient Comparison Across Specifications",
                 fontsize=13, fontweight="bold", pad=12)

    legend_elements = [
        Patch(facecolor="#2166ac", label="p < 0.01"),
        Patch(facecolor="#4393c3", label="p < 0.05"),
        Patch(facecolor="#d6604d", label="p < 0.10"),
        Patch(facecolor="0.55", label="p > 0.10"),
        plt.Line2D([0], [0], color="0.2", linewidth=0.9, linestyle="--",
                    alpha=0.5, label="Full-sample 2b"),
    ]
    ax.legend(handles=legend_elements, loc="upper left", frameon=False, fontsize=8.5)

    fig.savefig(FIGURES / "fig4_coefficient_forest.png")
    plt.close(fig)
    print("Figure 4 saved.")


# ── Run all ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Build the figures in numbered order, then report the output directory.
    figure_builders = (
        figure1_kaopen_marginal,
        figure2_jackknife,
        figure3_reallocation,
        figure4_forest,
    )
    for build_figure in figure_builders:
        build_figure()
    print(f"\nAll figures saved to {FIGURES}")
